package com.ming.common.liteflow.core.graph;

import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.collection.CollUtil;
import com.ming.common.beetl.util.StrUtil;
import com.ming.common.util.CommonUtil;
import com.ming.common.liteflow.core.json.ELJsonUtil;
import com.ming.common.liteflow.core.node.IvyCmp;
import com.ming.common.liteflow.core.node.NodeInfoWrapper;
import com.yomahub.liteflow.builder.el.WhenELWrapper;
import com.yomahub.liteflow.enums.NodeTypeEnum;
import lombok.Data;

import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Directed-graph view of a LogicFlow diagram used to derive LiteFlow EL expressions.
 *
 * <p>Nodes are stored in two adjacency maps: {@link #list} (node -> successors, forward order)
 * and {@link #reverseList} (node -> predecessors, reverse order). Both preserve node insertion
 * order because they are {@link LinkedHashMap}s.
 */
@Data
public class GraphEL {
    /** Forward adjacency: node -> successor nodes. */
    private Map<Node, List<Node>> list;
    private List<Edge> edgeList;
    /** Reverse adjacency: node -> predecessor nodes. */
    private Map<Node, List<Node>> reverseList;
    private Map<Long, IvyCmp> nodeInfoMap;
    private List<Node> groupParallelList;
    private List<Node> preList;
    private List<Node> finallyList;
    private List<Node> fallbackList;

    private Node startNode;
    private List<Node> startNodeList;
    private List<Node> endNode;

    public GraphEL() {
        this.list = new LinkedHashMap<>();
        this.reverseList = new LinkedHashMap<>();
    }

    /**
     * Builds a graph from the given flow data, using the flow's own component map.
     *
     * @param logicFlowData the parsed LogicFlow diagram
     * @return the populated graph
     */
    public static GraphEL getGraphEL(LogicFlowData logicFlowData) {
        return getGraphEL(logicFlowData, logicFlowData.getIvyCmpMap());
    }

    /**
     * Builds a graph from the given flow data.
     *
     * <p>Start nodes are sources of some edge that are never the target of any edge. When several
     * exist, {@code startNode} holds the last one found and {@code startNodeList} holds all of them.
     * When there are no edges, the first node (if any) becomes {@code startNode}. End nodes are the
     * nodes without successors.
     *
     * @param logicFlowData the parsed LogicFlow diagram
     * @param nodeInfoMap   component id -> component info, used to expand pre/finally nodes
     * @return the populated graph
     */
    public static GraphEL getGraphEL(LogicFlowData logicFlowData, Map<Long, IvyCmp> nodeInfoMap) {
        GraphEL graph = new GraphEL();
        graph.setNodeInfoMap(nodeInfoMap);
        graph.setGroupParallelList(logicFlowData.getGroupParallelList());
        graph.setPreList(handlerPreFinally(logicFlowData.getPreList(), nodeInfoMap));
        graph.setFinallyList(handlerPreFinally(logicFlowData.getFinallyList(), nodeInfoMap));
        graph.setFallbackList(handlerFallback(logicFlowData.getFallbackList(), nodeInfoMap));

        List<Node> nodes = logicFlowData.getNodes();
        if (nodes == null) {
            nodes = new ArrayList<>();
        }
        Map<String, Node> nodeMap = nodes.stream().collect(Collectors.toMap(Node::getId, m -> m));
        for (Node node : nodes) {
            graph.addNode(node);
        }

        List<Edge> edges = logicFlowData.getEdges();
        if (edges == null) {
            edges = new ArrayList<>();
        }
        graph.setEdgeList(edges);
        Set<String> targetNodes = new HashSet<>();
        for (Edge edge : edges) {
            graph.addEdge(nodeMap.get(edge.getSourceNodeId()), nodeMap.get(edge.getTargetNodeId()));
            targetNodes.add(edge.getTargetNodeId());
        }

        List<Node> startNodeList = new ArrayList<>();
        if (!edges.isEmpty()) {
            for (Edge edge : edges) {
                String sourceNodeId = edge.getSourceNodeId();
                if (!targetNodes.contains(sourceNodeId)) {
                    Node startNode = nodeMap.get(sourceNodeId);
                    graph.setStartNode(startNode);
                    if (!startNodeList.contains(startNode)) {
                        startNodeList.add(startNode);
                    }
                }
            }
        } else if (!nodes.isEmpty()) {
            // No edges: fall back to the first declared node (previously crashed on an empty list).
            graph.setStartNode(nodes.get(0));
        }

        graph.setEndNode(graph.getList().entrySet().stream()
                .filter(entry -> entry.getValue().isEmpty())
                .map(Map.Entry::getKey)
                .collect(Collectors.toList()));
        graph.setStartNodeList(startNodeList);
        return graph;
    }

    /**
     * Sets each fallback node's {@code fallbackType} from whichever fallback id is populated.
     * Ids are checked in a fixed priority order: COMMON, SWITCH, IF, FOR, WHILE, BREAK, ITERATOR.
     * Nodes without properties are skipped.
     *
     * @param nodes       fallback nodes (may be null); modified in place
     * @param nodeInfoMap currently unused
     * @return the same list that was passed in
     */
    private static List<Node> handlerFallback(List<Node> nodes, Map<Long, IvyCmp> nodeInfoMap) {
        if (nodes != null && !nodes.isEmpty()) {
            for (Node node : nodes) {
                NodeInfoWrapper prop = node.getProperties();
                if (prop == null) {
                    continue;
                }
                if (prop.getFallbackCommonId() != null) {
                    prop.setFallbackType(NodeTypeEnum.COMMON.getCode());
                } else if (prop.getFallbackSwitchId() != null) {
                    prop.setFallbackType(NodeTypeEnum.SWITCH.getCode());
                } else if (prop.getFallbackIfId() != null) {
                    prop.setFallbackType(NodeTypeEnum.IF.getCode());
                } else if (prop.getFallbackForId() != null) {
                    prop.setFallbackType(NodeTypeEnum.FOR.getCode());
                } else if (prop.getFallbackWhileId() != null) {
                    prop.setFallbackType(NodeTypeEnum.WHILE.getCode());
                } else if (prop.getFallbackBreakId() != null) {
                    prop.setFallbackType(NodeTypeEnum.BREAK.getCode());
                } else if (prop.getFallbackIteratorId() != null) {
                    prop.setFallbackType(NodeTypeEnum.ITERATOR.getCode());
                }
            }
        }
        return nodes;
    }

    /**
     * Expands pre/finally nodes: each id in a node's {@code properties.ids} becomes its own
     * {@link Node}. That node carries the id, type and text of the original node, with properties
     * copied from {@code nodeInfoMap.get(id)}.
     *
     * @param nodes       pre/finally nodes (may be null)
     * @param nodeInfoMap component id -> component info
     * @return the expanded list, or {@code null} when {@code nodes} is null or empty
     */
    private static List<Node> handlerPreFinally(List<Node> nodes, Map<Long, IvyCmp> nodeInfoMap) {
        if (nodes == null || nodes.isEmpty()) {
            return null;
        }
        List<Node> nodeList = new ArrayList<>();
        for (Node node : nodes) {
            NodeInfoWrapper properties = node.getProperties();
            if (properties == null || properties.getIds() == null) {
                continue;
            }
            for (String id : properties.getIds()) {
                NodeInfoWrapper nodeInfoWrapper = new NodeInfoWrapper();
                BeanUtil.copyProperties(nodeInfoMap.get(Long.parseLong(id)), nodeInfoWrapper);
                Node n = new Node();
                n.setId(node.getId());
                n.setType(node.getType());
                // Fixed: previously copied the new node's own (null) text instead of the source's.
                n.setText(node.getText());
                n.setProperties(nodeInfoWrapper);
                nodeList.add(n);
            }
        }
        return nodeList;
    }

    /** Registers a node with empty successor and predecessor lists. */
    public void addNode(Node node) {
        list.put(node, new ArrayList<>());
        reverseList.put(node, new ArrayList<>());
    }

    /** Adds a directed edge; both nodes must already have been added via {@link #addNode}. */
    public void addEdge(Node sourceNode, Node targetNode) {
        list.get(sourceNode).add(targetNode);
        reverseList.get(targetNode).add(sourceNode);
    }

    /** True when the graph has a single node or exactly one path from {@code startNode}. */
    public boolean isAllThen() {
        if (list.size() == 1) {
            return true;
        }
        return getAllPaths(startNode).size() == 1;
    }

    /** True when exactly one path leads from {@code currNode} to {@code forkNode}. */
    public boolean isThen(Node currNode, Node forkNode) {
        List<List<Node>> paths = getAllPaths(currNode, forkNode, false);
        return paths.size() == 1;
    }

    /** True when exactly one path leads from {@code currNode} to a terminal node. */
    public boolean isThen(Node currNode) {
        List<List<Node>> paths = getAllPaths(currNode, false);
        return paths.size() == 1;
    }

    /** True when no node has any successor (all nodes are independent). */
    public boolean isAllWhen() {
        return list.values().stream().allMatch(List::isEmpty);
    }

    /** True when more than one path leads from {@code currNode} to {@code forkNode}. */
    public boolean isWhen(Node currNode, Node forkNode) {
        List<List<Node>> paths = getAllPaths(currNode, forkNode, false);
        return paths.size() > 1;
    }

    /** True when {@code currNode} has no successors (returns false if end nodes were never computed). */
    public boolean isEndNode(Node currNode) {
        return endNode != null && endNode.contains(currNode);
    }

    /** Reference-equality check between the two nodes. */
    public boolean isEndNode(Node currNode, Node endNode) {
        return currNode == endNode;
    }

    /** Successors of {@code node}. */
    public List<Node> nextNode(Node node) {
        return list.get(node);
    }

    /** Predecessors of {@code node}. */
    public List<Node> prevNode(Node node) {
        return reverseList.get(node);
    }

    /**
     * Walks forward from {@code currNode} through single-successor nodes and returns the node
     * where the chain branches. The walk stops at a node that has zero or several successors,
     * or whose only successor is a merge point (has several predecessors). If a cycle is
     * detected, the node where it was detected is returned instead of recursing forever.
     *
     * @param currNode the node to start from
     * @return the node at which the walk stopped
     */
    public Node getForkNode(Node currNode) {
        Node current = currNode;
        Set<Node> seen = new HashSet<>();
        while (seen.add(current)) {
            List<Node> successors = list.get(current);
            if (successors.size() != 1) {
                return current;
            }
            Node only = successors.get(0);
            if (reverseList.get(only).size() > 1) {
                return current;
            }
            current = only;
        }
        return current;
    }

    /**
     * Returns the first node (in path order) that lies on every path starting from any node in
     * {@code nodeList}, or {@code null} when there is none or {@code nodeList} yields no paths.
     *
     * @param nodeList the fork's branch heads
     * @return the first node shared by all paths, or {@code null}
     */
    public Node getJoinNode(List<Node> nodeList) {
        List<List<Node>> allPaths = new ArrayList<>();
        for (Node node : nodeList) {
            allPaths.addAll(getAllPaths(node));
        }
        if (allPaths.isEmpty()) {
            return null;
        }
        Set<Node> commonNodes = new LinkedHashSet<>(allPaths.get(0));
        for (List<Node> path : allPaths) {
            commonNodes.retainAll(path);
        }
        return commonNodes.stream().findFirst().orElse(null);
    }

    /** First entry of {@link #getCommonNodesInAllPaths(Node)}, or {@code null}. */
    public Node getJoinNode(Node node) {
        List<Node> joinNodes = getCommonNodesInAllPaths(node);
        if (CommonUtil.collUtil.isNotEmpty(joinNodes)) {
            return joinNodes.get(0);
        }
        return null;
    }

    /** First entry of {@link #getCommonNodesInAllPaths(Node, List)}, or {@code null}. */
    public Node getJoinNode(Node node, List<Node> excludeList) {
        List<Node> joinNodes = getCommonNodesInAllPaths(node, excludeList);
        if (CommonUtil.collUtil.isNotEmpty(joinNodes)) {
            return joinNodes.get(0);
        }
        return null;
    }

    /**
     * Nodes present on every path from {@code startNode} to a terminal node. Paths that contain
     * any node of {@code excludeList} are ignored. The result excludes {@code startNode} itself
     * and excludes pure branch nodes (several successors, at most one predecessor).
     *
     * @return common nodes in path order; empty when no path survives the filter
     */
    public List<Node> getCommonNodesInAllPaths(Node startNode, List<Node> excludeList) {
        List<List<Node>> allPaths = getAllPaths(startNode);
        if (CollUtil.isNotEmpty(excludeList)) {
            allPaths = allPaths.stream()
                    .filter(path -> path.stream().noneMatch(excludeList::contains))
                    .collect(Collectors.toList());
        }
        return intersectPaths(allPaths, startNode, branchOnlyNodes());
    }

    /**
     * Nodes present on every path from {@code startNode} to a terminal node, excluding
     * {@code startNode} and pure branch nodes (several successors, at most one predecessor).
     *
     * @return common nodes in path order; empty when there are no paths
     */
    public List<Node> getCommonNodesInAllPaths(Node startNode) {
        return intersectPaths(getAllPaths(startNode), startNode, branchOnlyNodes());
    }

    /**
     * Nodes present on every path from {@code startNode} to {@code endNode}, excluding
     * {@code startNode} and every node with more than one successor.
     *
     * @return common nodes in path order; empty when no path connects the two nodes
     */
    public List<Node> getCommonNodesInAllPaths(Node startNode, Node endNode) {
        return intersectPaths(getAllPaths(startNode, endNode, false), startNode, branchNodes());
    }

    /**
     * Intersects all paths (keeping the order of the first path), then removes
     * {@code startNode} and every node in {@code excluded}.
     */
    private static List<Node> intersectPaths(List<List<Node>> allPaths, Node startNode, Set<Node> excluded) {
        if (allPaths.isEmpty()) {
            return new ArrayList<>();
        }
        Set<Node> commonNodes = new LinkedHashSet<>(allPaths.get(0));
        for (List<Node> path : allPaths) {
            commonNodes.retainAll(path);
        }
        commonNodes.remove(startNode);
        commonNodes.removeAll(excluded);
        return new ArrayList<>(commonNodes);
    }

    /** Nodes with more than one successor and at most one predecessor. */
    private Set<Node> branchOnlyNodes() {
        return list.entrySet().stream()
                .filter(entry -> entry.getValue().size() > 1 && reverseList.get(entry.getKey()).size() <= 1)
                .map(Map.Entry::getKey)
                .collect(Collectors.toSet());
    }

    /** Nodes with more than one successor. */
    private Set<Node> branchNodes() {
        return list.entrySet().stream()
                .filter(entry -> entry.getValue().size() > 1)
                .map(Map.Entry::getKey)
                .collect(Collectors.toSet());
    }

    /** Nodes with more than one predecessor and at most one successor. */
    private Set<Node> mergeOnlyNodes() {
        return reverseList.entrySet().stream()
                .filter(entry -> entry.getValue().size() > 1 && list.get(entry.getKey()).size() <= 1)
                .map(Map.Entry::getKey)
                .collect(Collectors.toSet());
    }

    /**
     * All paths from {@code startNode} to terminal nodes, with {@code startNode} removed from
     * each path.
     *
     * <p>NOTE: {@code excludeStartAndEnd} is currently ignored; the start node is always removed.
     * This is kept as-is for backward compatibility.
     */
    public List<List<Node>> getAllPaths(Node startNode, boolean excludeStartAndEnd) {
        List<List<Node>> allPaths = getAllPaths(startNode);
        for (List<Node> path : allPaths) {
            path.remove(startNode);
        }
        return allPaths;
    }

    /** All simple paths (no repeated node) from {@code startNode} to nodes with no successors. */
    public List<List<Node>> getAllPaths(Node startNode) {
        List<List<Node>> paths = new ArrayList<>();
        dfsGetAllPaths(startNode, new ArrayList<>(), paths, new HashSet<>());
        return paths;
    }

    /**
     * All simple paths from {@code startNode} to {@code endNode}.
     *
     * @param excludeStartAndEnd when true, nodes whose id equals the start's or end's id are
     *                           removed from each path
     */
    public List<List<Node>> getAllPaths(Node startNode, Node endNode, boolean excludeStartAndEnd) {
        return getAllPaths(startNode, endNode, excludeStartAndEnd ? 3 : 0);
    }

    /**
     * All simple paths from {@code startNode} to {@code endNode}, optionally dropping the
     * endpoints. Matching is by node id.
     *
     * @param excludeStartAndEnd 0: keep both; 1: drop start; 2: drop end; 3: drop both.
     *                           Any other value keeps both.
     */
    public List<List<Node>> getAllPaths(Node startNode, Node endNode, int excludeStartAndEnd) {
        List<List<Node>> paths = new ArrayList<>();
        dfsGetAllPaths(startNode, endNode, new ArrayList<>(), paths, new HashSet<>());
        boolean dropStart = excludeStartAndEnd == 1 || excludeStartAndEnd == 3;
        boolean dropEnd = excludeStartAndEnd == 2 || excludeStartAndEnd == 3;
        if (!dropStart && !dropEnd) {
            return paths;
        }
        return paths.stream()
                .map(path -> path.stream()
                        .filter(node -> !(dropStart && node.getId().equals(startNode.getId()))
                                && !(dropEnd && node.getId().equals(endNode.getId())))
                        .collect(Collectors.toList()))
                .collect(Collectors.toList());
    }

    /** Backtracking DFS collecting every simple path that ends at a node without successors. */
    private void dfsGetAllPaths(Node currentNode, List<Node> currentPath, List<List<Node>> paths, Set<Node> visited) {
        currentPath.add(currentNode);
        visited.add(currentNode);

        if (nextNode(currentNode).isEmpty()) {
            // Terminal node: record a copy of the current path.
            paths.add(new ArrayList<>(currentPath));
        } else {
            for (Node next : nextNode(currentNode)) {
                if (!visited.contains(next)) {
                    dfsGetAllPaths(next, currentPath, paths, visited);
                }
            }
        }

        // Backtrack.
        currentPath.remove(currentPath.size() - 1);
        visited.remove(currentNode);
    }

    /** Backtracking DFS collecting every simple path that ends at {@code endNode}. */
    private void dfsGetAllPaths(Node currentNode, Node endNode, List<Node> currentPath, List<List<Node>> paths, Set<Node> visited) {
        currentPath.add(currentNode);
        visited.add(currentNode);

        if (currentNode.equals(endNode)) {
            // Reached the target: record a copy of the current path.
            paths.add(new ArrayList<>(currentPath));
        } else {
            for (Node next : nextNode(currentNode)) {
                if (!visited.contains(next)) {
                    dfsGetAllPaths(next, endNode, currentPath, paths, visited);
                }
            }
        }

        // Backtrack.
        currentPath.remove(currentPath.size() - 1);
        visited.remove(currentNode);
    }

    /**
     * Groups paths by their first node. Paths sharing a first node with at least one other path
     * form one group each, and those groups come first. Every remaining path becomes a singleton
     * group, in original order. Paths are assumed to be non-empty.
     */
    public List<List<List<Node>>> handlerPaths(List<List<Node>> allPaths) {
        // First nodes that start more than one path.
        List<Node> repeatList = allPaths.stream()
                .map(path -> path.get(0))
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()))
                .entrySet().stream()
                .filter(entry -> entry.getValue() > 1)
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());

        List<List<List<Node>>> groupList = new ArrayList<>();
        for (Node head : repeatList) {
            groupList.add(allPaths.stream().filter(path -> path.get(0).equals(head)).collect(Collectors.toList()));
        }
        for (List<Node> path : allPaths) {
            if (!repeatList.contains(path.get(0))) {
                List<List<Node>> single = new ArrayList<>();
                single.add(path);
                groupList.add(single);
            }
        }
        return groupList;
    }

    /**
     * True when {@code startNode} branches out (more than one successor) and either has several
     * predecessors, or its first predecessor also branches out. A branching node with no
     * predecessors returns false (previously threw IndexOutOfBoundsException).
     */
    public boolean isXNode(Node startNode) {
        List<Node> successors = this.list.get(startNode);
        List<Node> predecessors = this.reverseList.get(startNode);
        if (successors.size() <= 1) {
            return false;
        }
        if (predecessors.size() > 1) {
            return true;
        }
        return !predecessors.isEmpty() && this.list.get(predecessors.get(0)).size() > 1;
    }

    /**
     * Returns the parallel-group node that contains {@code startNode}'s id. Failing that, returns
     * the group for which {@code ELJsonUtil.retainAll} reports a match between the group's children
     * and the ids of {@code startNode}'s successors. Returns {@code null} when no group matches.
     */
    public Node isGroupNode(Node startNode) {
        if (CommonUtil.collUtil.isNotEmpty(groupParallelList)) {
            Set<String> successorIds = null;
            for (Node node : groupParallelList) {
                if (node.getChildren() == null) {
                    continue;
                }
                if (node.getChildren().contains(startNode.getId())) {
                    return node;
                }
                if (successorIds == null) {
                    // Computed lazily and only once; it does not depend on the group node.
                    successorIds = list.get(startNode).stream().map(Node::getId).collect(Collectors.toSet());
                }
                if (ELJsonUtil.retainAll(node.getChildren(), successorIds)) {
                    return node;
                }
            }
        }
        return null;
    }

    /**
     * Applies the matching parallel-group node's WHEN options (ignoreError, any / must) to
     * {@code when}. Does nothing if there is no group or it has no properties.
     */
    public void setGroupNodeProp(Node startNode, WhenELWrapper when) {
        Node groupNode = isGroupNode(startNode);
        if (groupNode == null || groupNode.getProperties() == null) {
            return;
        }
        NodeInfoWrapper properties = groupNode.getProperties();
        if (properties.getWhenIgnoreError() != null) {
            when.ignoreError(properties.getWhenIgnoreError());
        }
        if (properties.getWhenAny() != null) {
            if (properties.getWhenAny()) {
                when.any(properties.getWhenAny());
            } else {
                String[] split = StrUtil.split(properties.getWhenMust());
                if (split != null) {
                    when.must(split);
                }
            }
        }
    }

    /**
     * Returns the properties of the edge {@code startNode -> endNode}, or {@code null}.
     *
     * <p>Side effect: if the edge's id and tag are both empty, the id is set on the edge's
     * properties object to the target node's component id.
     */
    public EdgeProperties getEdgeProperties(Node startNode, Node endNode) {
        if (CommonUtil.collUtil.isNotEmpty(edgeList)) {
            for (Edge edge : edgeList) {
                if (edge.getSourceNodeId().equals(startNode.getId()) && edge.getTargetNodeId().equals(endNode.getId())) {
                    EdgeProperties edgeProperties = edge.getProperties();
                    if (edgeProperties != null && endNode.getProperties() != null
                            && StrUtil.isEmpty(edgeProperties.getId()) && StrUtil.isEmpty(edgeProperties.getTag())) {
                        edgeProperties.setId(endNode.getProperties().getComponentId());
                    }
                    return edgeProperties;
                }
            }
        }
        return null;
    }

    /**
     * False when any outgoing edge of {@code currNode} has {@code linkType == 1}, otherwise true.
     * Edges without properties are treated as common edges.
     */
    public boolean isCommonEdge(Node currNode) {
        if (CommonUtil.collUtil.isNotEmpty(edgeList)) {
            for (Edge edge : edgeList) {
                if (edge.getSourceNodeId().equals(currNode.getId())
                        && edge.getProperties() != null
                        && edge.getProperties().getLinkType() == 1) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * True when {@code parentNode} is an ancestor of {@code childNode}, i.e. it can be reached by
     * following predecessor links. Uses a breadth-first search with a visited set, which stays
     * linear on shared ancestors and terminates on cycles.
     */
    public boolean isParentNode(Node childNode, Node parentNode) {
        Deque<Node> pending = new ArrayDeque<>(reverseList.get(childNode));
        Set<Node> seen = new HashSet<>();
        while (!pending.isEmpty()) {
            Node candidate = pending.poll();
            if (Objects.equals(parentNode, candidate)) {
                return true;
            }
            if (seen.add(candidate)) {
                pending.addAll(reverseList.get(candidate));
            }
        }
        return false;
    }

    /**
     * Walks predecessor links from {@code node} and returns the first node (nearest first) shared
     * by every backward path. Excludes {@code node} itself and pure merge nodes (several
     * predecessors, at most one successor). Returns {@code null} when nothing qualifies.
     */
    public Node prevJoinNode(Node node) {
        List<List<Node>> allPaths = new ArrayList<>();
        dfsGetAllPaths2(node, new ArrayList<>(), allPaths, new HashSet<>());
        List<Node> joinNodes = intersectPaths(allPaths, node, mergeOnlyNodes());
        if (CommonUtil.collUtil.isNotEmpty(joinNodes)) {
            return joinNodes.get(0);
        }
        return null;
    }

    /** Backtracking DFS over predecessor links, collecting paths that end at a node with no predecessors. */
    private void dfsGetAllPaths2(Node currentNode, List<Node> currentPath, List<List<Node>> paths, Set<Node> visited) {
        currentPath.add(currentNode);
        visited.add(currentNode);

        if (prevNode(currentNode).isEmpty()) {
            // No predecessors: record a copy of the current path.
            paths.add(new ArrayList<>(currentPath));
        } else {
            for (Node previous : prevNode(currentNode)) {
                if (!visited.contains(previous)) {
                    dfsGetAllPaths2(previous, currentPath, paths, visited);
                }
            }
        }

        // Backtrack.
        currentPath.remove(currentPath.size() - 1);
        visited.remove(currentNode);
    }

    /**
     * Whether the section between {@code currNode} and {@code joinNode} contains complex paths.
     *
     * <p>TODO: not implemented; always returns {@code false}. (The previous body ran a full path
     * search and discarded the result.)
     */
    public boolean isComplexPath(Node currNode, Node joinNode) {
        return false;
    }
}
